
package sf.r2dbc.dao;

import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import org.springframework.data.r2dbc.dialect.DialectResolver;
import org.springframework.data.r2dbc.dialect.R2dbcDialect;
import org.springframework.r2dbc.core.DatabaseClient;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import sf.database.dao.DBContext;
import sf.r2dbc.connection.R2dbcConnection;
import sf.r2dbc.sql.R2dbcUtils;

import java.util.function.Function;

/**
 * SORM adapter for R2DBC.
 * <p>
 * This adapter is compatible with JDBC mode: {@code ?} placeholders may be used in SQL and are
 * replaced automatically at execution time.
 *
 * @see org.springframework.data.r2dbc.core.R2dbcEntityTemplate
 */
public class R2dbcTemplate extends R2dbcClient {
    /** Spring client used to obtain connections for {@link #fluxFunc} and {@link #monoFunc}. */
    protected DatabaseClient databaseClient;

    /**
     * Creates a template that builds its own {@link DatabaseClient} from the given factory.
     * The client uses the bind markers of the dialect that Spring resolves for that factory.
     *
     * @param connectionFactory the R2DBC connection factory; must not be {@code null}
     * @throws IllegalArgumentException if {@code connectionFactory} is {@code null}
     */
    public R2dbcTemplate(ConnectionFactory connectionFactory) {
        // Validate inside the super(...) call so a null argument fails before the superclass sees it.
        super(requireConnectionFactory(connectionFactory));

        R2dbcDialect r2dbcDialect = DialectResolver.getDialect(connectionFactory);
        // dbDialect is declared outside this file (presumably in R2dbcClient).
        dbDialect = R2dbcUtils.getDBDialect(connectionFactory);
        this.databaseClient = DatabaseClient.builder().connectionFactory(connectionFactory)
                .bindMarkers(r2dbcDialect.getBindMarkersFactory()).build();
    }

    /**
     * Creates a template backed by an existing {@link DatabaseClient}.
     * The client is used as-is; no bind markers are reconfigured.
     *
     * @param databaseClient the Spring database client; must not be {@code null}
     * @throws IllegalArgumentException if {@code databaseClient} is {@code null}
     */
    public R2dbcTemplate(DatabaseClient databaseClient) {
        // The client must be checked before it is dereferenced in the super(...) argument.
        // Otherwise a null value raises a bare NullPointerException instead of the assertion.
        super(requireDatabaseClient(databaseClient).getConnectionFactory());
        this.databaseClient = databaseClient;
        dbDialect = R2dbcUtils.getDBDialect(databaseClient.getConnectionFactory());
    }

    /**
     * Runs {@code action} against a connection obtained from {@link #databaseClient}.
     * The connection is wrapped via {@code getR2dbcConnection} before {@code action} sees it.
     * <p>
     * The {@link DBContext} is read when this method is called (at assembly time), not when the
     * returned {@link Flux} is subscribed.
     *
     * @param action callback producing the result stream from the wrapped connection
     * @param <T>    element type of the result
     * @return the stream produced by {@code action}, scoped to the connection
     */
    @Override
    public <T> Flux<T> fluxFunc(Function<Connection, Flux<T>> action) {
        DBContext dbContext = getDBContext();
        return databaseClient.inConnectionMany(connection -> {
            R2dbcConnection r2dbcConnection = getR2dbcConnection(connection, dbContext);
            return action.apply(r2dbcConnection);
        });
    }

    /**
     * Runs {@code action} against a connection obtained from {@link #databaseClient}.
     * The connection is wrapped via {@code getR2dbcConnection} before {@code action} sees it.
     * <p>
     * The {@link DBContext} is read when this method is called (at assembly time), not when the
     * returned {@link Mono} is subscribed.
     *
     * @param action callback producing the single result from the wrapped connection
     * @param <T>    type of the result
     * @return the {@link Mono} produced by {@code action}, scoped to the connection
     */
    @Override
    public <T> Mono<T> monoFunc(Function<Connection, Mono<T>> action) {
        DBContext dbContext = getDBContext();
        return databaseClient.inConnection(connection -> {
            R2dbcConnection r2dbcConnection = getR2dbcConnection(connection, dbContext);
            return action.apply(r2dbcConnection);
        });
    }

    /**
     * Returns the underlying Spring client.
     *
     * @return the {@link DatabaseClient} used by this template
     */
    public DatabaseClient getDatabaseClient() {
        return databaseClient;
    }

    /** Asserts that the connection factory is non-null and returns it, for use in super(...) calls. */
    private static ConnectionFactory requireConnectionFactory(ConnectionFactory connectionFactory) {
        Assert.notNull(connectionFactory, "ConnectionFactory must not be null");
        return connectionFactory;
    }

    /** Asserts that the database client is non-null and returns it, for use in super(...) calls. */
    private static DatabaseClient requireDatabaseClient(DatabaseClient databaseClient) {
        Assert.notNull(databaseClient, "DatabaseClient must not be null");
        return databaseClient;
    }
}
